Source: Home0.jsx

import React, { Component } from 'react';
import "./Home0.css";
import GraphView from './ui/GraphView'
import DetailView from './ui/DetailView0'
import ParamView from './ui/ParamView'
import LoadView from './ui/LoadView'
import alertDialog from './ui/AlertDialog'
import { LucidBackend } from './scene/LucidBackend.js'
import Graph from './scene/Graph.js'
import { loadJSON, loadJSONFromLocalFile } from './util.js'
import { objectiveTypes, loadStates } from './LucidJS/src/optvis/renderer.js';
import { ttHelp } from './strings'
import { buildTour } from './TourBuilder'

import HelpIcon from '@material-ui/icons/Help';
import HelpOutlineIcon from '@material-ui/icons/HelpOutline';
import SchoolIcon from '@material-ui/icons/School';
import {
  IconButton, Fab, Tooltip,
  Modal, LinearProgress
} from "@material-ui/core";

/**
 * Modes of the secondary preview panel in the detail view.
 *
 * Frozen because it is a shared, enum-like constant: nothing should be able
 * to add, remove or rewrite a mode at runtime.
 */
export const scdryPrvMode = Object.freeze({
  COMPARE: 'compare',
  NEURON: 'neuron',
  PAINT: 'paint',
  ACT_ADJUST: 'act_adjust',
  STYLE: 'style',
});

/**
 * Selectable optimizer learning rates, in ascending order.
 * UI state stores an index into this array rather than the rate itself.
 */
export const learningRates = [
  5e-5, 1e-4, 5e-4,
  1e-3, 5e-3,
  1e-2, 5e-2,
  1e-1, 5e-1,
  1, 5,
];


export default class Home extends Component {

  constructor(props) {
    super(props);

    this.modelList = require('./model/models.json');
    this.imnetClasses = require('./model/imnet_classes.json');
    this.mnistClasses = require('./model/mnist_classes.json');
    this.graph2Key = require('./model/graphName2Key.json');

    this.state = {
      graph: null,
      modelName: "",

      inputSize: 128,
      pyramidLayers: 4,
      decorrelate: true,
      baseImage: null,

      layer: 'mixed4a_pre_relu',
      learningRate: 6,
      classInd: 0,
      classList: [],
      pyrLayerWeight: 1,

      prvMode: scdryPrvMode.COMPARE,

      alertDialogMessage: null,

      showHelp: false,
      tour: null,

      showProgress: false,
      modelProgress: 0,
    };
    this.LB = new LucidBackend();
  }

  graphview() {
    return (
      <GraphView
        loadStatus={this.state.loadStatus}
        modelName={this.state.modelName}
        graph={this.state.graph}
        clickedNode={(node) => {

          this.LB.setLayer(node.name);
          this.LB.setFeatureMapLayer(node.name);
          this.setState(this.state);
        }}
      />);
  }

  detailview() {
    let currentInput = this.LB.hasCurrentInput() ?
      this.LB.getCurrentInput() :
      undefined;

    return (<DetailView
      showHelp={this.state.showHelp}
      inputSize={this.LB.getInputSize()}
      loadStatus={this.LB.getLoadStatus()}
      selectedLayer={this.LB.getLayer()}
      currentInput={currentInput}
      inputShape={this.LB.getModelInputShape()}
      lastInput={this.LB.getLastInput()}
      lastInputShape={this.LB.getLastInputShape()}
      activations={this.LB.getCurrentActivations()}
      activationShape={this.LB.getActivationShape()}
      detailActivations={this.LB.getCurrentActivations(this.LB.getChannel())}
      detailActivationStats={this.LB.getActivationStats(this.LB.getChannel())}
      channelNumber={this.LB.getChannelNumber()}
      selectedChannel={this.LB.getChannel()}
      channelChanged={(chInd) => {
        this.LB.setChannel(chInd);
        this.setState(this.state);
      }}
      selectedNeuron={this.LB.getNeuron()}
      neuronChanged={(x, y) => {
        this.LB.setNeuron(x, y);
        this.setState(this.state);
      }}
      activationMods={this.LB.getActivationModifications()}
      activationsModified={(mods) => {
        this.LB.setActivationModifications(mods);
      }}
      onReset={() => {
        this.LB.resetInput();
        this.setState(this.state);
      }}
      prvMode={this.state.prvMode}
      prvModeChanged={(mode) => {
        this.setState({ prvMode: mode });
      }}
      styleImage={this.LB.getStyleImage()}
      styleImageShape={this.LB.getStyleImageShape()}
      uploadedStyleImage={(styleImage) => {
        this.LB.setStyleImage(styleImage);
        this.setState(this.state);
      }}
      layerList={this.state.graph ?
        this.state.graph.getSortedLayerList() : []}
      styleLayers={this.LB.getStyleLayers()}
      styleLayerChanged={(layerName, type) => {
        const dict = this.LB.getStyleLayers();
        const list = dict[type];

        const currentIndex = list.indexOf(layerName);
        const newChecked = [...list];

        if (currentIndex === -1) {
          newChecked.push(layerName);
        } else {
          newChecked.splice(currentIndex, 1);
        }
        this.LB.setStyleLayers(newChecked, type);
        this.setState(this.state);
      }}
    />);
  }

  loadTitle() {
    if (this.LB.getLoadStatus() === loadStates.INITIAL) {
      return "No graph loaded";
    } else if (this.LB.getLoadStatus() === loadStates.LOADING) {
      return "Loading " + this.state.modelName;
    } else {
      return this.state.modelName;
    }
  }

  progressCb = (progress) => {
    this.setState({ modelProgress: progress });
  }

  loadModelFromFile = (files) => {
    let topoFile;
    const weightFiles = [];
    for (const file of files) {
      if (file.name.endsWith('.json')) {
        topoFile = file;
      } else {
        weightFiles.push(file);
      }
    }
    loadJSONFromLocalFile(topoFile, (modelJson) => {
      let modelName = modelJson.modelTopology.model_config.config.name;
      modelName = this.graph2Key[modelName];
      const model = this.modelList[modelName];

      this.setStartModelLoadingState(model, modelJson);
    });

    this.LB.loadModelFromFile(topoFile, weightFiles, this.progressCb).then(() => {
      this.setFinishModelLoadingState();
    });
  }

  loadModel = (model) => {
    const modelPath = process.env.PUBLIC_URL + '/' + model.path;
    loadJSON(modelPath, (modelJson) => {
      this.setStartModelLoadingState(model, modelJson);
    });

    this.LB.loadModel(modelPath, this.progressCb).then(() => {
      this.setFinishModelLoadingState();
    });
  }

  setFinishModelLoadingState = () => {
    const ip = this.LB.getInputParams();
    this.setState({
      inputSize: ip.inputSize,
      pyramidLayers: ip.pyramidLayers,
      decorrelate: ip.decorrelate,
      baseImage: ip.baseImage,
      showProgress: false,
    });
    this.LB.setLearningRate(learningRates[this.state.learningRate]);
  }

  setStartModelLoadingState = (model, modelJson) => {
    const classList = this.getClasslist(model.classlist);

    this.setState({
      modelName: model.name,
      classList: classList,
      showProgress: true,
      modelProgress: 0,
    });

    if ("defaultStyleLayers" in this.modelList[model.name]) {
      const defaultStyleLayers =
        this.modelList[model.name]["defaultStyleLayers"];
      this.LB.setStyleLayers(defaultStyleLayers);
    }
    const graph = new Graph(modelJson);
    const distDict = graph.getLayoutByInputDist();
    this.setState({ graph: graph });
  }

  /**
   * Renders this react component
   * @returns {*} the components contents
   */
  render() {
    return (<div style={{ width: "100%", height: "100%", top: 0, left: 0 }}>
      <div style={{
        width: "100%", height: "100%", display: 'flex',
        flexDirection: 'row', justifyContent: "center", alignItems: "center"
      }}
        className="noScroll" >
        <div style={{
          flex: '0 1 auto', width: "40%",
          height: "100%", borderRight: '1px solid darkgray',
          display: "flex", flexDirection: "column"
        }}>
          <div style={{
            flex: '0 1 120px', width: "100%"
          }}>
            <LoadView
              models={this.modelList}
              title={this.loadTitle()}
              canOptimize={this.LB.canOptimize()}
              isOptimizing={this.LB.isOptimizing()}
              onLoadModel={this.loadModel}
              onLoadModelFromFile={this.loadModelFromFile}
              onOptimize={() => {
                const validationMessage = this.LB.validateOptimizationInput();
                if (validationMessage) {
                  this.setState({ alertDialogMessage: validationMessage });
                } else {
                  this.LB.startOptimization(20000, (stopped) => {
                    if (!stopped) {
                      this.setState({
                        loadStatus: loadStates.OPTIMIZING,
                      });
                    } else {
                      this.setState({ loadStatus: loadStates.LOADED });
                    }
                  });
                }
              }}
              stopOptimization={() => {
                this.LB.stopOptimization();
                this.setState({ loadStatus: loadStates.LOADED });
              }}
              showHelp={this.state.showHelp}
            />
          </div>
          <div style={{
            flex: '1 1 auto', width: "100%", height: "50%"
          }}>
            {this.graphview()}
          </div>
        </div>
        <div style={{
          display: 'flex', flexDirection: 'column',
          flex: '1 1 auto', width: "60%", height: '100%'
        }}>
          <div style={{
            flex: '0 1 auto', width: "100%", height: '190px',
            borderBottom: '1px solid darkgray'
          }}>
            <ParamView
              showHelp={this.state.showHelp}
              loadStatus={this.state.loadStatus}
              classList={this.state.classList}
              classInd={this.state.classInd}
              onApplyInputParams={() => {
                if (!this.LB.hasModel()) {
                  this.setState({
                    alertDialogMessage: [
                      "Can't apply input params.", 'Please load model first!'
                    ]
                  })
                  return;
                }
                const inputParams = {
                  inputSize: this.state.inputSize,
                  pyramidLayers: this.state.pyramidLayers,
                  decorrelate: this.state.decorrelate,
                  baseImage: this.state.baseImage,
                }
                this.LB.setInputParams(inputParams);
                this.setState(this.state);
              }}
              inputSize={this.state.inputSize}
              pyramidLayers={this.state.pyramidLayers}
              decorrelate={this.state.decorrelate}
              baseImage={this.state.baseImage}
              objectiveType={this.LB.getObjectiveType()}
              jitter={this.LB.getJitter()}
              negative={this.LB.getNegative()}
              learningRate={this.state.learningRate}
              learningRates={learningRates}
              pyrLayerWeight={this.state.pyrLayerWeight}
              handleInputChange={(stateChange) => {
                this.setState(stateChange);
              }
              }
              changedObjective={(newObjective) => {
                this.LB.setObjectiveType(newObjective);
                let prvMode;
                if (newObjective === objectiveTypes.NEURON) {
                  prvMode = scdryPrvMode.NEURON;
                } else if (newObjective === objectiveTypes.ACT_ADJUST) {
                  prvMode = scdryPrvMode.ACT_ADJUST;
                } else if (newObjective === objectiveTypes.STYLE) {
                  prvMode = scdryPrvMode.STYLE;
                }
                if (prvMode) {
                  this.setState({ prvMode: prvMode });
                } else {
                  this.setState(this.state);
                }
              }
              }
              changedClassInd={(newClassInd) => {
                this.LB.setClass(newClassInd);
                this.setState({
                  classInd: newClassInd,
                });
              }}
              changedJitter={(newJitter) => {
                this.LB.setJitter(newJitter);
                this.setState(this.state);
              }
              }
              changedNegative={(newNegative) => {
                this.LB.setNegative(newNegative);
                this.setState(this.state);
              }
              }
              changedLearningRate={(newLearningRate) => {
                this.LB.setLearningRate(learningRates[newLearningRate]);
                this.setState({ learningRate: newLearningRate });
              }}
              changedPyrLayerWeight={(newPyrLayerWeight) => {
                this.LB.setClassObjFrequencyLevelWeights(newPyrLayerWeight);
                this.setState({pyrLayerWeight: newPyrLayerWeight});
              }}
            />
          </div>
          <div style={{ flex: '1 1 auto', width: "100%", height: "50%" }}>
            {this.detailview()}
          </div>
        </div>
      </div>
      {alertDialog(this.state.alertDialogMessage,
        () => {
          this.setState({ alertDialogMessage: null });
        })}
      <Tooltip title={ttHelp}>
        <Fab
          color="primary"
          style={{
            position: "fixed",
            left: "10px",
            bottom: "10px",
            zIndex: 10
          }}
          onClick={() => {
            this.setState({ showHelp: !this.state.showHelp });
          }}>
          {this.state.showHelp ?
            (<HelpIcon />) : (<HelpOutlineIcon />)}
        </Fab>
      </Tooltip>
      <Tooltip title="Intro tour">
        <Fab
          color="primary"
          style={{
            position: "absolute",
            left: "10px",
            bottom: "80px",
            zIndex: 10
          }}
          onClick={this.startTour}>
          <SchoolIcon />
        </Fab>
      </Tooltip>
      <Modal open={this.state.showProgress}>
        <div className="modelProgressBar">
          <h2>Loading model...</h2>
          <LinearProgress
            variant="determinate"
            value={this.state.modelProgress * 100} />
        </div>
      </Modal>
    </div>
    )

  }

  startTour = () => {
    let tour;
    if (this.state.tour && this.state.tour.isActive()) {
      this.state.tour.hide();
      this.setState({ tour: null });
    } else {
      tour = buildTour();
      this.setState({ tour: tour });
      tour.start();
    }
    console.log(tour);
  }

  getClasslist(classListName) {
    if (classListName === 'imnet') {
      return this.imnetClasses;
    } else if (classListName === 'mnist') {
      return this.mnistClasses;
    }
  }
}